{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tsZybAeJT5My"
   },
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "bGSlM5fmT5M1",
    "outputId": "2af3a12b-8736-4c9f-abe8-4d1f7a1c45e5"
   },
   "outputs": [],
   "source": [
    "# For Colab only!\n",
    "\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_style": "split",
    "colab": {},
    "colab_type": "code",
    "id": "0XwG0znET5M5"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_style": "split",
    "colab": {},
    "colab_type": "code",
    "id": "BQl-XD9UT5M9"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 126
    },
    "colab_type": "code",
    "id": "Jo8WpQW9T5NA",
    "outputId": "83af30a7-8219-4baa-b712-23d964e87931",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n",
      "WARNING:tensorflow:From <ipython-input-3-2e82a26757ae>:2: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.list_physical_devices('GPU')` instead.\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(tf.__version__)\n",
    "print(tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "zKAqsnciT5NF",
    "outputId": "f7e01863-388c-4de0-b0b6-c4d77f6c620a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "i-PslWs3T5NI"
   },
   "source": [
    "### Data Loading\n",
    "MINST data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CshDQz-RT5NJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch_size=32\n",
    "learning_rate=0.01\n",
    "epochs=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "mJT3QyR4T5NP",
    "outputId": "eadf7dfc-362f-408c-e35a-6a9150412d1d"
   },
   "outputs": [],
   "source": [
    "(x, y),(x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "ds_train = tf.data.Dataset.from_tensor_slices((x,y))\n",
    "ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n",
    "\n",
    "def preprocess(x, y):\n",
    "  x = (tf.cast(x, tf.float32)/255)-0.1307\n",
    "  y = tf.cast(y, tf.int32)\n",
    "#   y = tf.one_hot(y,depth=10)   \n",
    "  return x, y\n",
    "\n",
    "ds_train = ds_train.map(preprocess).shuffle(1000).batch(batch_size)\n",
    "ds_test = ds_test.map(preprocess).shuffle(1000).batch(batch_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 403,
     "referenced_widgets": [
      "a783de3fb1df4a759c4a7a6ab555ff1d",
      "0da4f96ecec6439a817f042d572d5e99",
      "44923357d143465ea3103974c355755b",
      "cbe3c769b91549e0ab5e15bd6fb93fb6",
      "f5faf9b762e4420989a98aab8d32f597",
      "281848dbd3c549bcb74c2363b79741c7",
      "11845a8cd886472e8c318006c5f1ea9b",
      "656204aaabc24c7bb6cc56fdeb0349a2",
      "7d5870d79704417db9285c8375a913d6",
      "bb4f5574e0f44a968947cd445f379fe2",
      "02fe7e2fd1ad450aa443e79b40021f9a",
      "e061bb86760149febe95e8dc87e5bd6e",
      "25286640a78548929a48756a80280f82",
      "4c8d0bab7a1742ebb2c4ba0584185d9b",
      "9b96bcf1413d4cdfa5b282a597bbcdca",
      "a80d98e195734ea0a574499a67a7633b",
      "f07590c929ae410aa1ca9ef9d2e62b92",
      "0ff34bcf2a684c618e9bd55127c25b8f",
      "8b3da754dc994ce19c7610a41db3c196",
      "f753962277694dbeae6895f3d4e840e3",
      "7bfe437215e647c3b59d25474d9e274a",
      "0ff3dac3b9f74a0e8bb3b3f80c96157b",
      "ada1c65efd2442e29d472e2548649179",
      "717e3dbfa276440d925bd36b322234bd",
      "0da98dec9f2747888a3deba8cf51541a",
      "30b0963345e449fab2818d8da4956a00",
      "e2ea136c67c84594afac559b0b1ed5b5",
      "6333efb6efa242f0a4440d9556c47894",
      "2cc4f906e262464da4ad80404e251556",
      "737540f756f64669822b9aa1571724a8",
      "6541d5e97607431db0b84aa5367e0c7b",
      "6b66b888fcc64caabc3d1a04866b0b51"
     ]
    },
    "colab_type": "code",
    "id": "Hhc1Ik_iT5NS",
    "outputId": "317427f6-37ea-46b6-8ef6-69a5d81326bc",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])),\n",
    "    batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "U0jZwxEDT5NV"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "phBTsVL4T5Nb",
    "outputId": "f4a88c65-95f2-481d-fdfe-181a11b06a60"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>\n",
      "(32, 28, 28) (32,)\n"
     ]
    }
   ],
   "source": [
    "print(type(ds_test))\n",
    "image, label = next(iter(ds_test))\n",
    "print(image.shape, label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "cell_style": "split",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "tZ0uInfKT5Ng",
    "outputId": "47313afd-a621-46d6-ad68-911abffeabca",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.utils.data.dataloader.DataLoader'>\n",
      "torch.Size([32, 1, 28, 28]) torch.Size([32])\n"
     ]
    }
   ],
   "source": [
    "print(type(train_loader))\n",
    "[image, label] = next(iter(train_loader))\n",
    "print(image.shape, label.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vDX0rbORT5Nr"
   },
   "source": [
    "### CNN\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fkmqyFxdU2fZ"
   },
   "source": [
    "Tensorflow\n",
    "\n",
    "[b, 28, 28 1]  with kernel [3, 1, 3, 3] padding = \"same\", stride = 1 -> [b, 14, 14, 3]\n",
    "\n",
    "> NOTE:\n",
    ">\n",
    "> Here In `layers.Conv2D` parameters: filters = 3, kernel_size = (3,3), strides = (1,1)\n",
    "\n",
    "[b, 14, 14, 3] -> maxpooling -> [b, 7, 7, 3]\n",
    "\n",
    "[b, 14, 14, 3]  with kernel [6, 3, 3, 3] padding = \"same\", stride = 2 -> [b, 7, 7, 6]\n",
    "\n",
    "[b, 7, 7, 6]  flatten() -> [b, 7*7*6]\n",
    "\n",
    "[b, 7*7*6] -> Dense -> [b, 10]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cell_style": "center",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "-h_iJsUiT5Nr",
    "outputId": "1e84d763-ce6d-4ed7-ce4d-5f5555d46b0e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"cnn_model_2\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "sequential_2 (Sequential)    multiple                  3148      \n",
      "=================================================================\n",
      "Total params: 3,148\n",
      "Trainable params: 3,148\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "epoch:0, step:0 loss:2.328885078430176\n",
      "accuracy:  0.1409\n",
      "epoch:0, step:100 loss:0.14413821697235107\n",
      "accuracy:  0.8648\n",
      "epoch:0, step:200 loss:0.28817737102508545\n",
      "accuracy:  0.9059\n",
      "epoch:0, step:300 loss:0.17282797396183014\n",
      "accuracy:  0.9134\n",
      "epoch:0, step:400 loss:0.0806659460067749\n",
      "accuracy:  0.9312\n",
      "epoch:0, step:500 loss:0.06717875599861145\n",
      "accuracy:  0.9398\n",
      "epoch:0, step:600 loss:0.2847217917442322\n",
      "accuracy:  0.9508\n",
      "epoch:0, step:700 loss:0.22499491274356842\n",
      "accuracy:  0.9481\n",
      "epoch:0, step:800 loss:0.16353292763233185\n",
      "accuracy:  0.9512\n",
      "epoch:0, step:900 loss:0.06671904027462006\n",
      "accuracy:  0.9484\n",
      "epoch:0, step:1000 loss:0.09817676246166229\n",
      "accuracy:  0.9579\n",
      "epoch:0, step:1100 loss:0.019231852144002914\n",
      "accuracy:  0.9481\n",
      "epoch:0, step:1200 loss:0.2960127592086792\n",
      "accuracy:  0.9593\n",
      "epoch:0, step:1300 loss:0.14271831512451172\n",
      "accuracy:  0.955\n",
      "epoch:0, step:1400 loss:0.16546659171581268\n",
      "accuracy:  0.9663\n",
      "epoch:0, step:1500 loss:0.020800674334168434\n",
      "accuracy:  0.9644\n",
      "epoch:0, step:1600 loss:0.31070148944854736\n",
      "accuracy:  0.9661\n",
      "epoch:0, step:1700 loss:0.03529775142669678\n",
      "accuracy:  0.9642\n",
      "epoch:0, step:1800 loss:0.2327214628458023\n",
      "accuracy:  0.9489\n",
      "epoch:1, step:0 loss:0.024368826299905777\n",
      "accuracy:  0.9667\n",
      "epoch:1, step:100 loss:0.044243842363357544\n",
      "accuracy:  0.965\n",
      "epoch:1, step:200 loss:0.2215174436569214\n",
      "accuracy:  0.9625\n",
      "epoch:1, step:300 loss:0.10046914964914322\n",
      "accuracy:  0.9666\n",
      "epoch:1, step:400 loss:0.06817823648452759\n",
      "accuracy:  0.9673\n",
      "epoch:1, step:500 loss:0.04892420023679733\n",
      "accuracy:  0.9693\n",
      "epoch:1, step:600 loss:0.0573093518614769\n",
      "accuracy:  0.9748\n",
      "epoch:1, step:700 loss:0.030420873314142227\n",
      "accuracy:  0.9739\n",
      "epoch:1, step:800 loss:0.173701673746109\n",
      "accuracy:  0.9734\n",
      "epoch:1, step:900 loss:0.09255193173885345\n",
      "accuracy:  0.9735\n",
      "epoch:1, step:1000 loss:0.4066876173019409\n",
      "accuracy:  0.9718\n",
      "epoch:1, step:1100 loss:0.038255561143159866\n",
      "accuracy:  0.9739\n",
      "epoch:1, step:1200 loss:0.08493136614561081\n",
      "accuracy:  0.9739\n",
      "epoch:1, step:1300 loss:0.051499076187610626\n",
      "accuracy:  0.969\n",
      "epoch:1, step:1400 loss:0.14321663975715637\n",
      "accuracy:  0.9752\n",
      "epoch:1, step:1500 loss:0.02896244078874588\n",
      "accuracy:  0.9715\n",
      "epoch:1, step:1600 loss:0.1450221985578537\n",
      "accuracy:  0.9728\n",
      "epoch:1, step:1700 loss:0.02648809365928173\n",
      "accuracy:  0.9718\n",
      "epoch:1, step:1800 loss:0.02126728743314743\n",
      "accuracy:  0.9695\n",
      "epoch:2, step:0 loss:0.013731674291193485\n",
      "accuracy:  0.9723\n",
      "epoch:2, step:100 loss:0.26277053356170654\n",
      "accuracy:  0.9702\n",
      "epoch:2, step:200 loss:0.10695470869541168\n",
      "accuracy:  0.9728\n",
      "epoch:2, step:300 loss:0.021643271669745445\n"
     ]
    }
   ],
   "source": [
    "class CNN_model(keras.Model):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "        self.model = keras.Sequential(\n",
    "            [layers.Conv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding=\"same\"),\n",
    "            layers.MaxPool2D(pool_size=(2,2)),\n",
    "            layers.ReLU(),\n",
    "            layers.Conv2D(6,(3,3),(2,2),\"same\"),\n",
    "            layers.ReLU(),\n",
    "            layers.Flatten(),\n",
    "            layers.Dense(10)]\n",
    "            )\n",
    "    \n",
    "    def call(self,x):\n",
    "        x = self.model(x)\n",
    "        \n",
    "        return x\n",
    "    \n",
    "model = CNN_model()\n",
    "model.build(input_shape = (None,28,28,1))\n",
    "model.summary()\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "    \n",
    "for epoch in range(epochs):\n",
    "    \n",
    "    for step, (x, y) in enumerate(ds_train):\n",
    "        x = tf.reshape(x, [-1, 28,28,1])\n",
    "        with tf.GradientTape() as tape:            \n",
    "            logits = model(x)\n",
    "            \n",
    "            losses = tf.losses.sparse_categorical_crossentropy(y,logits,from_logits=True)\n",
    "            loss = tf.reduce_mean(losses)\n",
    "            \n",
    "        grads = tape.gradient(loss, model.variables)\n",
    "        \n",
    "        optimizer.apply_gradients(zip(grads, model.variables))\n",
    "        \n",
    "        if(step%100==0):\n",
    "            print(\"epoch:{}, step:{} loss:{}\".\n",
    "                  format(epoch, step, loss.numpy()))\n",
    "            \n",
    "            \n",
    "#             test accuracy: \n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "            \n",
    "            for x_test, y_test in ds_test:\n",
    "                x_test = tf.reshape(x_test, [-1, 28,28,1])\n",
    "                y_pred = tf.argmax(model(x_test),axis=1)\n",
    "                y_pred = tf.cast(y_pred, tf.int32)\n",
    "                correct = tf.cast((y_pred == y_test), tf.int32)\n",
    "                correct = tf.reduce_sum(correct)\n",
    "                \n",
    "                total_correct += int(correct)\n",
    "                total_num += x_test.shape[0]\n",
    "        \n",
    "            \n",
    "            accuracy = total_correct/total_num\n",
    "            print('accuracy: ', accuracy)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zI7zP4RHZRpM"
   },
   "source": [
    "PyTorch\n",
    "\n",
    "[b, 1, 28, 28]  with kernel [3, 1, 3, 3] padding = 1, stride = 1 -> [b, 3, 28, 28]\n",
    "\n",
    "> NOTE:\n",
    ">\n",
 Here In `nn.Conv2d`">
    "> Here, the `nn.Conv2d` parameters are: out_channels = 3, kernel_size = 3, stride = 1, padding = 1\n",
    "\n",
    "[b, 3, 28, 28] with `nn.MaxPool2d(kernel_size=2)` -> [b, 3, 14, 14]\n",
    "\n",
    "[b, 3, 14, 14]  with kernel [6, 3, 3, 3] padding = 1, stride = 2 -> [b, 6, 7, 7]\n",
    "\n",
    "[b, 6, 7, 7]  flatten() -> [b, 6 * 7 * 7]\n",
    "\n",
    "[b, 6 * 7 * 7] -> Dense -> [b, 10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "cell_style": "center",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "mHUWourvT5Nx",
    "outputId": "56147d91-26d9-42c6-e0c3-5107822e595f"
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 7.93 GiB total capacity; 0 bytes already allocated; 24.50 MiB free; 0 bytes reserved in total by PyTorch)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-224ecfdbf67a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cuda:0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mnetwork\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCNN_NN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m optimizer = torch.optim.Adam(network.parameters(),\n\u001b[1;32m     23\u001b[0m                             lr=learning_rate)\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    423\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 425\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    426\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    427\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_backward_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    200\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    203\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    200\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    202\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    203\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    221\u001b[0m                 \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    222\u001b[0m                 \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 223\u001b[0;31m                     \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    224\u001b[0m                 \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    225\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m    421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    422\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 423\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    425\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 7.93 GiB total capacity; 0 bytes already allocated; 24.50 MiB free; 0 bytes reserved in total by PyTorch)"
     ]
    }
   ],
   "source": [
    "class CNN_NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "        self.model = nn.Sequential(\n",
    "            nn.Conv2d(in_channels=1,out_channels=3,kernel_size=3,stride=1,padding=1),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv2d(3,6,kernel_size=3,stride=2,padding=1),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(6*7*7,10)\n",
    "            )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.model(x)\n",
    "        \n",
    "        return x\n",
    "device = torch.device('cuda:0')\n",
    "\n",
    "network = CNN_NN().to(device)        \n",
    "optimizer = torch.optim.Adam(network.parameters(),\n",
    "                            lr=learning_rate)\n",
    "criteon = torch.nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    \n",
    "    for step, (x, y) in enumerate(train_loader):\n",
    "        x = x.reshape(-1,1,28,28)\n",
    "        \n",
    "        x, y = x.to(device), y.to(device)\n",
    "        \n",
    "        logits = network(x)\n",
    "        loss = criteon(logits, y)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if(step%100 == 0):\n",
    "            print(\"epoch:{}, step:{}, loss:{}\".\n",
    "                  format(epoch, step, loss.item()))\n",
    "        \n",
    "#             test accuracy\n",
    "            total_correct = 0\n",
    "            total_num = 0    \n",
    "\n",
    "            for x_test, y_test in test_loader:\n",
    "                    x_test = x_test.reshape(-1,1,28,28)\n",
    "                    x_test, y_test = x_test.to(device), y_test.to(device)\n",
    "\n",
    "                    y_pred = network(x_test)\n",
    "                    y_pred = torch.argmax(y_pred, dim = 1)\n",
    "                    correct = y_pred == y_test\n",
    "                    correct = correct.sum()\n",
    "\n",
    "                    total_correct += correct\n",
    "                    total_num += x_test.shape[0]\n",
    "\n",
    "            acc = total_correct.float()/total_num\n",
    "            print(\"accuracy: \", acc.item())\n",
    "                \n",
    "                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NtxvEuzZaU85"
   },
   "outputs": [],
   "source": [
    "# OUT SIZE TEST BLOCK FOR PYTORCH\n",
    "\n",
    "x = torch.rand(1,1,28,28)\n",
    "layer1 = nn.Conv2d(1,3,kernel_size=3,stride=1,padding=1)\n",
    "layer2 = nn.MaxPool2d(kernel_size=2)\n",
    "layer3 = nn.Conv2d(3,6,kernel_size=3,stride=2,padding=1)\n",
    "\n",
    "out = layer1(x)\n",
    "print(out.shape)\n",
    "\n",
    "out = layer2(out)\n",
    "\n",
    "print(out.shape)\n",
    "\n",
    "out = layer3(out)\n",
    "print(out.shape)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "CNN_nn.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "02fe7e2fd1ad450aa443e79b40021f9a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4c8d0bab7a1742ebb2c4ba0584185d9b",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_25286640a78548929a48756a80280f82",
      "value": 1
     }
    },
    "0da4f96ecec6439a817f042d572d5e99": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0da98dec9f2747888a3deba8cf51541a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_e2ea136c67c84594afac559b0b1ed5b5",
       "IPY_MODEL_6333efb6efa242f0a4440d9556c47894"
      ],
      "layout": "IPY_MODEL_30b0963345e449fab2818d8da4956a00"
     }
    },
    "0ff34bcf2a684c618e9bd55127c25b8f": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0ff3dac3b9f74a0e8bb3b3f80c96157b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "11845a8cd886472e8c318006c5f1ea9b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "25286640a78548929a48756a80280f82": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "281848dbd3c549bcb74c2363b79741c7": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2cc4f906e262464da4ad80404e251556": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "30b0963345e449fab2818d8da4956a00": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "44923357d143465ea3103974c355755b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_281848dbd3c549bcb74c2363b79741c7",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f5faf9b762e4420989a98aab8d32f597",
      "value": 1
     }
    },
    "4c8d0bab7a1742ebb2c4ba0584185d9b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "6333efb6efa242f0a4440d9556c47894": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_6b66b888fcc64caabc3d1a04866b0b51",
      "placeholder": "​",
      "style": "IPY_MODEL_6541d5e97607431db0b84aa5367e0c7b",
      "value": "8192it [00:00, 63151.30it/s]"
     }
    },
    "6541d5e97607431db0b84aa5367e0c7b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "656204aaabc24c7bb6cc56fdeb0349a2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "6b66b888fcc64caabc3d1a04866b0b51": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "717e3dbfa276440d925bd36b322234bd": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "737540f756f64669822b9aa1571724a8": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "7bfe437215e647c3b59d25474d9e274a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "7d5870d79704417db9285c8375a913d6": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_02fe7e2fd1ad450aa443e79b40021f9a",
       "IPY_MODEL_e061bb86760149febe95e8dc87e5bd6e"
      ],
      "layout": "IPY_MODEL_bb4f5574e0f44a968947cd445f379fe2"
     }
    },
    "8b3da754dc994ce19c7610a41db3c196": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0ff3dac3b9f74a0e8bb3b3f80c96157b",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_7bfe437215e647c3b59d25474d9e274a",
      "value": 1
     }
    },
    "9b96bcf1413d4cdfa5b282a597bbcdca": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "a783de3fb1df4a759c4a7a6ab555ff1d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_44923357d143465ea3103974c355755b",
       "IPY_MODEL_cbe3c769b91549e0ab5e15bd6fb93fb6"
      ],
      "layout": "IPY_MODEL_0da4f96ecec6439a817f042d572d5e99"
     }
    },
    "a80d98e195734ea0a574499a67a7633b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "ada1c65efd2442e29d472e2548649179": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "bb4f5574e0f44a968947cd445f379fe2": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "cbe3c769b91549e0ab5e15bd6fb93fb6": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_656204aaabc24c7bb6cc56fdeb0349a2",
      "placeholder": "​",
      "style": "IPY_MODEL_11845a8cd886472e8c318006c5f1ea9b",
      "value": "9920512it [00:00, 24750672.85it/s]"
     }
    },
    "e061bb86760149febe95e8dc87e5bd6e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a80d98e195734ea0a574499a67a7633b",
      "placeholder": "​",
      "style": "IPY_MODEL_9b96bcf1413d4cdfa5b282a597bbcdca",
      "value": "32768it [00:00, 180427.64it/s]"
     }
    },
    "e2ea136c67c84594afac559b0b1ed5b5": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "IntProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_737540f756f64669822b9aa1571724a8",
      "max": 1,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_2cc4f906e262464da4ad80404e251556",
      "value": 1
     }
    },
    "f07590c929ae410aa1ca9ef9d2e62b92": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_8b3da754dc994ce19c7610a41db3c196",
       "IPY_MODEL_f753962277694dbeae6895f3d4e840e3"
      ],
      "layout": "IPY_MODEL_0ff34bcf2a684c618e9bd55127c25b8f"
     }
    },
    "f5faf9b762e4420989a98aab8d32f597": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "f753962277694dbeae6895f3d4e840e3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_717e3dbfa276440d925bd36b322234bd",
      "placeholder": "​",
      "style": "IPY_MODEL_ada1c65efd2442e29d472e2548649179",
      "value": "1654784it [00:00, 4611594.95it/s]"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
